{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "### Crossover算子\n",
    "值得一提的是，DEAP中GP默认实现的Crossover算子不考虑根节点。因此，如果要按照GP的原始论文实现，需要稍作修改。"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "8db4ada5ce6ebf73"
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "from deap import base, creator, tools, gp\n",
    "\n",
    "\n",
    "# Symbolic regression\n",
    "def evalSymbReg(individual, pset):\n",
    "    \"\"\"Return the fitness of a GP tree as a one-element tuple ``(mse,)``.\n",
    "\n",
    "    The tree is compiled with ``gp.compile`` and called on 100 evenly spaced\n",
    "    points in [-10, 10]; the error is measured against ``x**2``.\n",
    "\n",
    "    :param individual: Expression tree to evaluate.\n",
    "    :param pset: Primitive set handed to ``gp.compile``.\n",
    "    :returns: ``(mse,)``, the mean squared error wrapped in a tuple.\n",
    "    \"\"\"\n",
    "    sample_points = np.linspace(-10, 10, 100)\n",
    "    candidate = gp.compile(expr=individual, pset=pset)\n",
    "\n",
    "    # Mean squared error between the candidate's output and the target x**2\n",
    "    residual = candidate(sample_points) - sample_points**2\n",
    "    return (np.mean(residual**2),)\n",
    "\n",
    "# Create the fitness class and the individual class.\n",
    "# The single weight of -1.0 marks the objective (the MSE returned by\n",
    "# evalSymbReg) as one to be minimised; individuals are GP primitive trees\n",
    "# that carry a FitnessMin instance as their ``fitness`` attribute.\n",
    "creator.create(\"FitnessMin\", base.Fitness, weights=(-1.0,))\n",
    "creator.create(\"Individual\", gp.PrimitiveTree, fitness=creator.FitnessMin)\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-07T08:49:09.672369400Z",
     "start_time": "2023-11-07T08:49:09.564823400Z"
    }
   },
   "id": "initial_id"
  },
  {
   "cell_type": "markdown",
   "source": [
    "具体来说，需要修改交叉点的取值范围，以包括根节点。"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e3d94e424b58af5a"
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "\n",
    "import random\n",
    "from collections import defaultdict\n",
    "\n",
    "__type__ = object\n",
    "\n",
    "\n",
    "def cxOnePoint(ind1, ind2):\n",
    "    \"\"\"Exchange one randomly chosen subtree between two GP trees, in place.\n",
    "\n",
    "    Unlike DEAP's stock ``gp.cxOnePoint``, the root (index 0) is also a\n",
    "    candidate crossover point, so a whole tree may be swapped.\n",
    "\n",
    "    :param ind1: First tree participating in the crossover.\n",
    "    :param ind2: Second tree participating in the crossover.\n",
    "    :returns: The tuple ``(ind1, ind2)``; both objects are modified in place.\n",
    "    \"\"\"\n",
    "    # Map each return type to the node positions that produce it\n",
    "    types1 = defaultdict(list)\n",
    "    types2 = defaultdict(list)\n",
    "    if ind1.root.ret == __type__:\n",
    "        # Loosely typed GP: every position, root included, is interchangeable\n",
    "        types1[__type__] = list(range(len(ind1)))\n",
    "        types2[__type__] = list(range(len(ind2)))\n",
    "        common_types = [__type__]\n",
    "    else:\n",
    "        # Strongly typed GP: enumerate from 0 so that each recorded index\n",
    "        # really is the position of the node whose type was read\n",
    "        for idx, node in enumerate(ind1):\n",
    "            types1[node.ret].append(idx)\n",
    "        for idx, node in enumerate(ind2):\n",
    "            types2[node.ret].append(idx)\n",
    "        # Ordered list instead of a set, so a seeded run is reproducible\n",
    "        common_types = [type_ for type_ in types1 if type_ in types2]\n",
    "\n",
    "    if common_types:\n",
    "        type_ = random.choice(common_types)\n",
    "\n",
    "        index1 = random.choice(types1[type_])\n",
    "        index2 = random.choice(types2[type_])\n",
    "\n",
    "        # Swap the two subtrees rooted at the chosen positions\n",
    "        slice1 = ind1.searchSubtree(index1)\n",
    "        slice2 = ind2.searchSubtree(index2)\n",
    "        ind1[slice1], ind2[slice2] = ind2[slice2], ind1[slice1]\n",
    "\n",
    "    return ind1, ind2"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-07T08:49:09.678933300Z",
     "start_time": "2023-11-07T08:49:09.675377100Z"
    }
   },
   "id": "5dde655dc691a423"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\zhenl\\anaconda3\\Lib\\site-packages\\deap\\gp.py:254: RuntimeWarning: Ephemeral rand101 function cannot be pickled because its generating function is a lambda function. Use functools.partial instead.\n",
      "  warnings.warn(\"Ephemeral {name} function cannot be \"\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "from functools import partial\n",
    "\n",
    "# Function set and terminal set\n",
    "pset = gp.PrimitiveSet(\"MAIN\", arity=1)\n",
    "pset.addPrimitive(np.add, 2)\n",
    "pset.addPrimitive(np.subtract, 2)\n",
    "pset.addPrimitive(np.multiply, 2)\n",
    "pset.addPrimitive(np.negative, 1)\n",
    "# A partial instead of a lambda keeps the ephemeral constant picklable and\n",
    "# avoids DEAP's RuntimeWarning; it still draws an integer from {-1, 0, 1}\n",
    "pset.addEphemeralConstant(\"rand101\", partial(random.randint, -1, 1))\n",
    "pset.renameArguments(ARG0='x')\n",
    "\n",
    "# Genetic programming operators\n",
    "toolbox = base.Toolbox()\n",
    "toolbox.register(\"expr\", gp.genHalfAndHalf, pset=pset, min_=1, max_=2)\n",
    "toolbox.register(\"individual\", tools.initIterate, creator.Individual, toolbox.expr)\n",
    "toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual)\n",
    "toolbox.register(\"compile\", gp.compile, pset=pset)\n",
    "toolbox.register(\"evaluate\", evalSymbReg, pset=pset)\n",
    "toolbox.register(\"select\", tools.selTournament, tournsize=3)\n",
    "# Crossover uses the notebook's root-inclusive cxOnePoint, not gp.cxOnePoint\n",
    "toolbox.register(\"mate\", cxOnePoint)\n",
    "toolbox.register(\"mutate\", gp.mutUniform, expr=toolbox.expr, pset=pset)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-07T08:49:09.694753600Z",
     "start_time": "2023-11-07T08:49:09.680991300Z"
    }
   },
   "id": "cb6cf38094256262"
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   \t      \t                    fitness                    \t                     size                    \n",
      "   \t      \t-----------------------------------------------\t---------------------------------------------\n",
      "gen\tnevals\tavg    \tgen\tmax   \tmin\tnevals\tstd    \tavg    \tgen\tmax\tmin\tnevals\tstd  \n",
      "0  \t300   \t2715.84\t0  \t159956\t0  \t300   \t9168.43\t4.18333\t0  \t7  \t2  \t300   \t1.799\n",
      "1  \t172   \t44557.2\t1  \t1.23323e+07\t0  \t172   \t710775 \t4.69333\t1  \t12 \t1  \t172   \t2.49519\n",
      "2  \t178   \t83962.7\t2  \t1.23323e+07\t0  \t178   \t979048 \t5.16   \t2  \t15 \t1  \t178   \t2.89846\n",
      "3  \t186   \t4491.27\t3  \t155827     \t0  \t186   \t21401.5\t5.81667\t3  \t20 \t1  \t186   \t3.56787\n",
      "4  \t182   \t6884.08\t4  \t608604     \t0  \t182   \t41827.6\t6.56333\t4  \t19 \t1  \t182   \t3.67324\n",
      "5  \t199   \t49752.6\t5  \t1.24881e+07\t0  \t199   \t720046 \t6.66   \t5  \t19 \t1  \t199   \t3.76533\n",
      "6  \t170   \t9782.06\t6  \t155827     \t0  \t170   \t35314.4\t6.74667\t6  \t19 \t1  \t170   \t3.74644\n",
      "7  \t190   \t52741.3\t7  \t1.23323e+07\t0  \t190   \t711213 \t6.83   \t7  \t27 \t1  \t190   \t4.39634\n",
      "8  \t185   \t15256.6\t8  \t159956     \t0  \t185   \t44576.8\t6.69   \t8  \t19 \t1  \t185   \t4.3788 \n",
      "9  \t188   \t57930.5\t9  \t1.17257e+07\t0  \t188   \t677238 \t7.5    \t9  \t20 \t1  \t188   \t4.68295\n",
      "10 \t169   \t48659  \t10 \t1.18774e+07\t0  \t169   \t684937 \t7.67   \t10 \t24 \t1  \t169   \t5.1037 \n",
      "11 \t190   \t12291.9\t11 \t159956     \t0  \t190   \t40001.2\t8.23   \t11 \t34 \t1  \t190   \t5.72571\n",
      "12 \t200   \t3.36087e+06\t12 \t1.00337e+09\t0  \t200   \t5.78318e+07\t9.33333\t12 \t31 \t1  \t200   \t6.2872 \n",
      "13 \t181   \t3.51385e+06\t13 \t1.00337e+09\t0  \t181   \t5.78881e+07\t10.81  \t13 \t32 \t1  \t181   \t6.99099\n",
      "14 \t164   \t166673     \t14 \t1.21786e+07\t0  \t164   \t1.37053e+06\t11.79  \t14 \t32 \t1  \t164   \t7.3298 \n",
      "15 \t194   \t3.35859e+06\t15 \t1.00337e+09\t0  \t194   \t5.78319e+07\t13.3433\t15 \t34 \t1  \t194   \t7.65586\n",
      "16 \t186   \t3.60051e+06\t16 \t1.00306e+09\t0  \t186   \t5.78788e+07\t13.9733\t16 \t38 \t1  \t186   \t8.20524\n",
      "17 \t186   \t209138     \t17 \t1.24839e+07\t0  \t186   \t1.55124e+06\t15.67  \t17 \t39 \t1  \t186   \t8.63449\n",
      "18 \t203   \t132987     \t18 \t1.40065e+07\t0  \t203   \t1.25645e+06\t17.1267\t18 \t52 \t1  \t203   \t9.00207\n",
      "19 \t166   \t3.09085e+08\t19 \t9.06199e+10\t0  \t166   \t5.22347e+09\t18.5133\t19 \t53 \t1  \t166   \t9.84157\n",
      "20 \t177   \t2.57857e+10\t20 \t7.64709e+12\t0  \t177   \t4.40781e+11\t19.0333\t20 \t61 \t1  \t177   \t10.3691\n",
      "21 \t183   \t2.34401e+12\t21 \t7.03113e+14\t0  \t183   \t4.05265e+13\t20.6633\t21 \t68 \t1  \t183   \t11.2997\n",
      "22 \t190   \t4.62999e+09\t22 \t1.38492e+12\t0  \t190   \t7.98244e+10\t22.0267\t22 \t68 \t1  \t190   \t11.341 \n",
      "23 \t178   \t3.50733e+06\t23 \t1.01539e+09\t0  \t178   \t5.85307e+07\t25.86  \t23 \t81 \t1  \t178   \t12.9087\n",
      "24 \t171   \t3.31331e+06\t24 \t9.67885e+08\t0  \t171   \t5.57919e+07\t27.1133\t24 \t81 \t3  \t171   \t13.4261\n",
      "25 \t176   \t1.13102e+09\t25 \t3.38252e+11\t0  \t176   \t1.94963e+10\t29.55  \t25 \t91 \t1  \t176   \t15.8352\n",
      "26 \t185   \t2.40453e+10\t26 \t7.21356e+12\t0  \t185   \t4.1578e+11 \t29.84  \t26 \t106\t1  \t185   \t17.9048\n",
      "27 \t165   \t3.55088e+06\t27 \t9.91189e+08\t0  \t165   \t5.71939e+07\t30.8433\t27 \t108\t1  \t165   \t18.9858\n",
      "28 \t185   \t7.13202e+06\t28 \t1.05147e+09\t0  \t185   \t8.36634e+07\t30.7833\t28 \t108\t1  \t185   \t20.1121\n",
      "29 \t169   \t2.75476e+08\t29 \t8.25924e+10\t0  \t169   \t4.76051e+09\t34.4967\t29 \t109\t1  \t169   \t22.0248\n",
      "30 \t177   \t2.88609e+08\t30 \t8.65576e+10\t0  \t177   \t4.98906e+09\t36.1267\t30 \t118\t1  \t177   \t23.8665\n",
      "31 \t182   \t1739.48    \t31 \t153712     \t0  \t182   \t15291.9    \t37.52  \t31 \t128\t1  \t182   \t24.7013\n",
      "32 \t195   \t1.67291e+07\t32 \t4.01346e+09\t0  \t195   \t2.37767e+08\t40.9167\t32 \t224\t1  \t195   \t26.6789\n",
      "33 \t173   \t2.8527e+08 \t33 \t8.55423e+10\t0  \t173   \t4.93054e+09\t43.31  \t33 \t224\t1  \t173   \t32.4   \n",
      "34 \t185   \t1.20699e+11\t34 \t2.85626e+13\t0  \t185   \t1.70287e+12\t47.33  \t34 \t224\t1  \t185   \t36.9719\n",
      "35 \t179   \t5.83821e+08\t35 \t8.75609e+10\t0  \t179   \t7.08471e+09\t47.9867\t35 \t224\t1  \t179   \t36.5101\n",
      "36 \t158   \t5.84712e+08\t36 \t8.76092e+10\t0  \t158   \t7.08668e+09\t54.2833\t36 \t227\t1  \t158   \t40.1071\n",
      "37 \t179   \t3.42709e+06\t37 \t1.01539e+09\t0  \t179   \t5.85275e+07\t54.31  \t37 \t227\t1  \t179   \t42.2209\n",
      "38 \t172   \t1.01978e+11\t38 \t3.05924e+13\t0  \t172   \t1.7633e+12 \t58.2333\t38 \t227\t1  \t172   \t41.8791\n",
      "39 \t199   \t5314.13    \t39 \t606522     \t0  \t199   \t40830.6    \t63.55  \t39 \t255\t1  \t199   \t50.0976\n",
      "40 \t180   \t44321.1    \t40 \t1.20269e+07\t0  \t180   \t694103     \t65.0333\t40 \t261\t1  \t180   \t50.8197\n",
      "41 \t186   \t2.26667e+12\t41 \t6.8e+14    \t0  \t186   \t3.91943e+13\t65.24  \t41 \t259\t1  \t186   \t52.4852\n",
      "42 \t161   \t3.2596e+08 \t42 \t9.47769e+10\t0  \t161   \t5.46315e+09\t69.54  \t42 \t334\t1  \t161   \t58.7825\n",
      "43 \t181   \t1.72076e+07\t43 \t4.15839e+09\t0  \t181   \t2.46211e+08\t75.62  \t43 \t340\t1  \t181   \t64.9814\n",
      "44 \t178   \t5.63796e+08\t44 \t8.45629e+10\t0  \t178   \t6.88147e+09\t76.7633\t44 \t407\t1  \t178   \t66.4288\n",
      "45 \t199   \t2.3698e+12 \t45 \t7.1076e+14 \t0  \t199   \t4.09673e+13\t81.69  \t45 \t407\t5  \t199   \t67.4995\n",
      "46 \t177   \t2.54942e+10\t46 \t7.64309e+12\t0  \t177   \t4.40537e+11\t81.87  \t46 \t403\t3  \t177   \t66.3547\n",
      "47 \t188   \t1.08948e+09\t47 \t3.26818e+11\t0  \t188   \t1.88374e+10\t82.7633\t47 \t357\t1  \t188   \t62.0543\n",
      "48 \t173   \t1.42365e+07\t48 \t4.27076e+09\t0  \t173   \t2.46161e+08\t86.82  \t48 \t375\t1  \t173   \t63.6163\n",
      "49 \t186   \t1186.38    \t49 \t153712     \t0  \t186   \t12514.9    \t87.6367\t49 \t343\t1  \t186   \t60.8458\n",
      "50 \t165   \t84904      \t50 \t1.29305e+07\t0  \t165   \t1.01613e+06\t92.9433\t50 \t276\t1  \t165   \t63.3678\n",
      "time: 2.3149585723876953\n",
      "multiply(x, x)\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "from deap import algorithms\n",
    "\n",
    "# Statistics recorded every generation: fitness values and tree size\n",
    "stats_fit = tools.Statistics(lambda ind: ind.fitness.values)\n",
    "stats_size = tools.Statistics(len)\n",
    "mstats = tools.MultiStatistics(fitness=stats_fit, size=stats_size)\n",
    "mstats.register(\"avg\", numpy.mean)\n",
    "mstats.register(\"std\", numpy.std)\n",
    "mstats.register(\"min\", numpy.min)\n",
    "mstats.register(\"max\", numpy.max)\n",
    "\n",
    "# Run DEAP's stock eaSimple loop and time it.\n",
    "# perf_counter() is a monotonic clock, so the measured interval cannot be\n",
    "# distorted by system clock adjustments the way time.time() can.\n",
    "start = time.perf_counter()\n",
    "population = toolbox.population(n=300)\n",
    "hof = tools.HallOfFame(1)\n",
    "pop, log = algorithms.eaSimple(\n",
    "    population=population,\n",
    "    toolbox=toolbox,\n",
    "    cxpb=0.5,\n",
    "    mutpb=0.2,\n",
    "    ngen=50,\n",
    "    stats=mstats,\n",
    "    halloffame=hof,\n",
    "    verbose=True,\n",
    ")\n",
    "end = time.perf_counter()\n",
    "print('time:', end - start)\n",
    "print(str(hof[0]))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-11-07T08:49:12.030799500Z",
     "start_time": "2023-11-07T08:49:09.694753600Z"
    }
   },
   "id": "88c62bc071d56191"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
